{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FinBERT Example Notebook\n",
    "\n",
    "This notebook shows how to train and use the FinBERT pre-trained language model for financial sentiment analysis."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modules "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:04.902740Z",
     "start_time": "2020-03-23T15:55:04.876252Z"
    }
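   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import shutil\n",
    "import os\n",
    "import logging\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "# Imported explicitly rather than relying on the star import below\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "from textblob import TextBlob\n",
    "from pprint import pprint\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "from finbert.finbert import *\n",
    "import finbert.utils as tools\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "project_dir = Path.cwd().parent\n",
    "# None disables column truncation (the old value -1 is deprecated in recent pandas)\n",
    "pd.set_option('display.max_colwidth', None)"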
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:05.711210Z",
     "start_time": "2020-03-23T15:55:05.693609Z"
    }
   },
   "outputs": [],
   "source": [
    "logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',\n",
    "                    datefmt = '%m/%d/%Y %H:%M:%S',\n",
    "                    level = logging.ERROR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting path variables:\n",
    "1. `lm_path`: the path of the pre-trained language model (not needed if vanilla BERT is used).\n",
    "2. `cl_path`: the path where the classification model is saved.\n",
    "3. `cl_data_path`: the path of the directory that contains the data files `train.csv`, `validation.csv`, and `test.csv`.\n",
    "---\n",
    "\n",
    "When initializing `bertmodel`, we can load either the original pre-trained weights from Google by passing `'bert-base-uncased'` to `from_pretrained`, or our further pre-trained language model by passing `lm_path`.\n",
    "\n",
    "\n",
    "---\n",
    "All model configuration is controlled through the `config` variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:07.405597Z",
     "start_time": "2020-03-23T15:55:07.386378Z"
    }
   },
   "outputs": [],
   "source": [
    "lm_path = project_dir/'models'/'language_model'/'finbertTRC2'\n",
    "cl_path = project_dir/'models'/'classifier_model'/'finbert-sentiment'\n",
    "cl_data_path = project_dir/'data'/'sentiment_data'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuring training parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Explanations of the training parameters can be found in the class docstrings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:12.378583Z",
     "start_time": "2020-03-23T15:55:09.196746Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /home/ubuntu/finbert/models/language_model/finbertTRC2 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/ubuntu/finbert/models/language_model/finbertTRC2 and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Clean the cl_path (a missing directory is not an error)\n",
    "shutil.rmtree(cl_path, ignore_errors=True)\n",
    "\n",
    "bertmodel = AutoModelForSequenceClassification.from_pretrained(lm_path, cache_dir=None, num_labels=3)\n",
    "\n",
    "\n",
    "config = Config(   data_dir=cl_data_path,\n",
    "                   bert_model=bertmodel,\n",
    "                   num_train_epochs=4,\n",
    "                   model_dir=cl_path,\n",
    "                   max_seq_length = 48,\n",
    "                   train_batch_size = 32,\n",
    "                   learning_rate = 2e-5,\n",
    "                   output_mode='classification',\n",
    "                   warm_up_proportion=0.2,\n",
    "                   local_rank=-1,\n",
    "                   discriminate=True,\n",
    "                   gradual_unfreeze=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`FinBert` is our main class that encapsulates all the functionality. The list of class labels should be passed to the `prepare_model` method through the `label_list` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:16.657078Z",
     "start_time": "2020-03-23T15:55:16.639644Z"
    }
   },
   "outputs": [],
   "source": [
    "finbert = FinBert(config)\n",
    "finbert.base_model = 'bert-base-uncased'\n",
    "finbert.config.discriminate=True\n",
    "finbert.config.gradual_unfreeze=True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:17.850734Z",
     "start_time": "2020-03-23T15:55:17.368073Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:45:20 - INFO - finbert.finbert -   device: cuda n_gpu: 1, distributed training: False, 16-bits training: False\n"
     ]
    }
   ],
   "source": [
    "finbert.prepare_model(label_list=['positive','negative','neutral'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tune the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:19.395707Z",
     "start_time": "2020-03-23T15:55:19.349642Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Get the training examples\n",
    "train_data = finbert.get_data('train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:55:25.912424Z",
     "start_time": "2020-03-23T15:55:20.065887Z"
    }
   },
   "outputs": [],
   "source": [
    "model = finbert.create_the_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Optional] Fine-tune only a subset of the model\n",
    "The variable `freeze` sets how many encoder layers (out of 12), starting from the first, are frozen; the embedding layer is frozen as well. You can skip this part if you want to fine-tune the whole model.\n",
    "\n",
    "<span style=\"color:red\">Important: </span>\n",
    "Execute this step if you want a shorter training time at the expense of accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is for fine-tuning a subset of the model.\n",
    "\n",
    "freeze = 6\n",
    "\n",
    "for param in model.bert.embeddings.parameters():\n",
    "    param.requires_grad = False\n",
    "    \n",
    "for i in range(freeze):\n",
    "    for param in model.bert.encoder.layer[i].parameters():\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:58:35.486890Z",
     "start_time": "2020-03-23T15:55:27.293772Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:45:23 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:45:23 - INFO - finbert.utils -   guid: train-1\n",
      "12/24/2020 12:45:23 - INFO - finbert.utils -   tokens: [CLS] after the reporting period , bio ##tie north american licensing partner so ##max ##on pharmaceuticals announced positive results with na ##lm ##efe ##ne in a pilot phase 2 clinical trial for smoking ce ##ssa ##tion [SEP]\n",
      "12/24/2020 12:45:23 - INFO - finbert.utils -   input_ids: 101 2044 1996 7316 2558 1010 16012 9515 2167 2137 13202 4256 2061 17848 2239 24797 2623 3893 3463 2007 6583 13728 27235 2638 1999 1037 4405 4403 1016 6612 3979 2005 9422 8292 11488 3508 102 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:23 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:23 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:23 - INFO - finbert.utils -   label: positive (id = 0)\n",
      "12/24/2020 12:45:24 - INFO - finbert.finbert -   ***** Loading data *****\n",
      "12/24/2020 12:45:24 - INFO - finbert.finbert -     Num examples = 3488\n",
      "12/24/2020 12:45:24 - INFO - finbert.finbert -     Batch size = 32\n",
      "12/24/2020 12:45:24 - INFO - finbert.finbert -     Num steps = 48\n",
      "Epoch:   0%|          | 0/4 [00:00<?, ?it/s]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e82b9092f7214c8097f81fea553e8d20",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=109.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:45:28 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:45:28 - INFO - finbert.utils -   guid: validation-1\n",
      "12/24/2020 12:45:28 - INFO - finbert.utils -   tokens: [CLS] our in - depth expertise extends to the fields of energy , industry , urban & mobility and water & environment [SEP]\n",
      "12/24/2020 12:45:28 - INFO - finbert.utils -   input_ids: 101 2256 1999 1011 5995 11532 8908 2000 1996 4249 1997 2943 1010 3068 1010 3923 1004 12969 1998 2300 1004 4044 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:28 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:28 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:28 - INFO - finbert.utils -   label: neutral (id = 2)\n",
      "12/24/2020 12:45:28 - INFO - finbert.finbert -   ***** Loading data *****\n",
      "12/24/2020 12:45:28 - INFO - finbert.finbert -     Num examples = 388\n",
      "12/24/2020 12:45:28 - INFO - finbert.finbert -     Batch size = 32\n",
      "12/24/2020 12:45:28 - INFO - finbert.finbert -     Num steps = 48\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6104bd00bad749c49895f5827e1e1926",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Validation losses: [0.8303732826159551]\n",
      "No best model found\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  25%|██▌       | 1/4 [00:05<00:16,  5.65s/it]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "29e0754cae2b410895cf2f4fce2c751b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=109.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:45:37 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:45:37 - INFO - finbert.utils -   guid: validation-1\n",
      "12/24/2020 12:45:37 - INFO - finbert.utils -   tokens: [CLS] our in - depth expertise extends to the fields of energy , industry , urban & mobility and water & environment [SEP]\n",
      "12/24/2020 12:45:37 - INFO - finbert.utils -   input_ids: 101 2256 1999 1011 5995 11532 8908 2000 1996 4249 1997 2943 1010 3068 1010 3923 1004 12969 1998 2300 1004 4044 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:37 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:37 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:37 - INFO - finbert.utils -   label: neutral (id = 2)\n",
      "12/24/2020 12:45:38 - INFO - finbert.finbert -   ***** Loading data *****\n",
      "12/24/2020 12:45:38 - INFO - finbert.finbert -     Num examples = 388\n",
      "12/24/2020 12:45:38 - INFO - finbert.finbert -     Batch size = 32\n",
      "12/24/2020 12:45:38 - INFO - finbert.finbert -     Num steps = 48\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "53efda63ce514236b96b1b840fb8d526",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Validation losses: [0.8303732826159551, 0.4003512469621805]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  50%|█████     | 2/4 [00:14<00:13,  6.69s/it]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a63e57297ed4a2db1c16300941d3f5d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=109.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:45:50 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:45:50 - INFO - finbert.utils -   guid: validation-1\n",
      "12/24/2020 12:45:50 - INFO - finbert.utils -   tokens: [CLS] our in - depth expertise extends to the fields of energy , industry , urban & mobility and water & environment [SEP]\n",
      "12/24/2020 12:45:50 - INFO - finbert.utils -   input_ids: 101 2256 1999 1011 5995 11532 8908 2000 1996 4249 1997 2943 1010 3068 1010 3923 1004 12969 1998 2300 1004 4044 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:50 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:50 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:45:50 - INFO - finbert.utils -   label: neutral (id = 2)\n",
      "12/24/2020 12:45:50 - INFO - finbert.finbert -   ***** Loading data *****\n",
      "12/24/2020 12:45:50 - INFO - finbert.finbert -     Num examples = 388\n",
      "12/24/2020 12:45:50 - INFO - finbert.finbert -     Batch size = 32\n",
      "12/24/2020 12:45:50 - INFO - finbert.finbert -     Num steps = 48\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84777611fa484ab9869b87167368e2bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Validation losses: [0.8303732826159551, 0.4003512469621805, 0.3436752695303697]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  75%|███████▌  | 3/4 [00:27<00:08,  8.47s/it]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9dd5c6d46aa94eb6a2ab8233de807855",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=109.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:46:04 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:46:04 - INFO - finbert.utils -   guid: validation-1\n",
      "12/24/2020 12:46:04 - INFO - finbert.utils -   tokens: [CLS] our in - depth expertise extends to the fields of energy , industry , urban & mobility and water & environment [SEP]\n",
      "12/24/2020 12:46:04 - INFO - finbert.utils -   input_ids: 101 2256 1999 1011 5995 11532 8908 2000 1996 4249 1997 2943 1010 3068 1010 3923 1004 12969 1998 2300 1004 4044 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:04 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:04 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:04 - INFO - finbert.utils -   label: neutral (id = 2)\n",
      "12/24/2020 12:46:04 - INFO - finbert.finbert -   ***** Loading data *****\n",
      "12/24/2020 12:46:04 - INFO - finbert.finbert -     Num examples = 388\n",
      "12/24/2020 12:46:04 - INFO - finbert.finbert -     Batch size = 32\n",
      "12/24/2020 12:46:04 - INFO - finbert.finbert -     Num steps = 48\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39d1fcd5c14d4b81a74d812d9b150a84",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=0.0, max=13.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 100%|██████████| 4/4 [00:40<00:00, 10.20s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Validation losses: [0.8303732826159551, 0.4003512469621805, 0.3436752695303697, 0.34499044028612286]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "trained_model = finbert.train(train_examples = train_data, model = model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the model\n",
    "\n",
    "`finbert.evaluate` returns a DataFrame that holds the true label and the logit values for each example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:58:40.056789Z",
     "start_time": "2020-03-23T15:58:40.023198Z"
    }
   },
   "outputs": [],
   "source": [
    "test_data = finbert.get_data('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:58:48.248044Z",
     "start_time": "2020-03-23T15:58:41.699009Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:46:05 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:46:05 - INFO - finbert.utils -   guid: test-1\n",
      "12/24/2020 12:46:05 - INFO - finbert.utils -   tokens: [CLS] the bristol port company has sealed a one million pound contract with cooper specialised handling to supply it with four 45 - ton ##ne , custom ##ised reach stack ##ers from ko ##ne ##cr ##ane ##s [SEP]\n",
      "12/24/2020 12:46:05 - INFO - finbert.utils -   input_ids: 101 1996 7067 3417 2194 2038 10203 1037 2028 2454 9044 3206 2007 6201 17009 8304 2000 4425 2009 2007 2176 3429 1011 10228 2638 1010 7661 5084 3362 9991 2545 2013 12849 2638 26775 7231 2015 102 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:05 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:05 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:05 - INFO - finbert.utils -   label: positive (id = 0)\n",
      "12/24/2020 12:46:06 - INFO - finbert.finbert -   ***** Loading data *****\n",
      "12/24/2020 12:46:06 - INFO - finbert.finbert -     Num examples = 970\n",
      "12/24/2020 12:46:06 - INFO - finbert.finbert -     Batch size = 32\n",
      "12/24/2020 12:46:06 - INFO - finbert.finbert -     Num steps = 120\n",
      "12/24/2020 12:46:06 - INFO - finbert.finbert -   ***** Running evaluation ***** \n",
      "12/24/2020 12:46:06 - INFO - finbert.finbert -     Num examples = 970\n",
      "12/24/2020 12:46:06 - INFO - finbert.finbert -     Batch size = 32\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9622f14f4530449b961f0450fd6805b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Testing'), FloatProgress(value=0.0, max=31.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = finbert.evaluate(examples=test_data, model=trained_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the classification report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:58:51.361079Z",
     "start_time": "2020-03-23T15:58:51.339548Z"
    }
   },
   "outputs": [],
   "source": [
    "def report(df, cols=('label', 'prediction', 'logits')):\n",
    "    # cols names the true-label, predicted-label and logits columns of df, in that order\n",
    "    label_col, pred_col, logit_col = cols\n",
    "    # Class-weighted cross-entropy between the logits and the true labels\n",
    "    cs = CrossEntropyLoss(weight=finbert.class_weights)\n",
    "    loss = cs(torch.tensor(list(df[logit_col])), torch.tensor(list(df[label_col])))\n",
    "    print(\"Loss:{0:.2f}\".format(loss))\n",
    "    print(\"Accuracy:{0:.2f}\".format((df[label_col] == df[pred_col]).sum() / df.shape[0]))\n",
    "    print(\"\\nClassification Report:\")\n",
    "    print(classification_report(df[label_col], df[pred_col]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:58:53.190447Z",
     "start_time": "2020-03-23T15:58:53.166729Z"
    }
   },
   "outputs": [],
   "source": [
    "results['prediction'] = results.predictions.apply(lambda x: np.argmax(x,axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:58:54.436270Z",
     "start_time": "2020-03-23T15:58:54.399174Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss:0.39\n",
      "Accuracy:0.83\n",
      "\n",
      "Classification Report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.72      0.87      0.78       267\n",
      "           1       0.78      0.89      0.83       128\n",
      "           2       0.92      0.80      0.85       575\n",
      "\n",
      "    accuracy                           0.83       970\n",
      "   macro avg       0.80      0.85      0.82       970\n",
      "weighted avg       0.84      0.83      0.83       970\n",
      "\n"
     ]
    }
   ],
   "source": [
    "report(results,cols=['labels','prediction','predictions'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `predict` function splits a piece of text into sentences and predicts the sentiment of each sentence. The output is a DataFrame in which the predictions are represented in three columns:\n",
    "\n",
    "1) `logit`: probabilities for each class\n",
    "\n",
    "2) `prediction`: predicted label\n",
    "\n",
    "3) `sentiment_score`: sentiment score calculated as: probability of positive - probability of negative\n",
    "\n",
    "Below we analyze a paragraph taken from [this article](https://www.economist.com/finance-and-economics/2019/01/03/a-profit-warning-from-apple-jolts-markets) in The Economist. For comparison, we also include the sentiments predicted with TextBlob.\n",
    "> Later that day Apple said it was revising down its earnings expectations in the fourth quarter of 2018, largely because of lower sales and signs of economic weakness in China. The news rapidly infected financial markets. Apple’s share price fell by around 7% in after-hours trading and the decline was extended to more than 10% when the market opened. The dollar fell by 3.7% against the yen in a matter of minutes after the announcement, before rapidly recovering some ground. Asian stockmarkets closed down on January 3rd and European ones opened lower. Yields on government bonds fell as investors fled to the traditional haven in a market storm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:03.875213Z",
     "start_time": "2020-03-23T15:59:03.857612Z"
    }
   },
   "outputs": [],
   "source": [
    "text = \"Later that day Apple said it was revising down its earnings expectations in \\\n",
    "the fourth quarter of 2018, largely because of lower sales and signs of economic weakness in China. \\\n",
    "The news rapidly infected financial markets. Apple’s share price fell by around 7% in after-hours \\\n",
    "trading and the decline was extended to more than 10% when the market opened. The dollar fell \\\n",
    "by 3.7% against the yen in a matter of minutes after the announcement, before rapidly recovering \\\n",
    "some ground. Asian stockmarkets closed down on January 3rd and European ones opened lower. \\\n",
    "Yields on government bonds fell as investors fled to the traditional haven in a market storm.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:16.246963Z",
     "start_time": "2020-03-23T15:59:13.285393Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cl_path = project_dir/'models'/'classifier_model'/'finbert-sentiment'\n",
    "model = AutoModelForSequenceClassification.from_pretrained(cl_path, cache_dir=None, num_labels=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import nltk\n",
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:19.531663Z",
     "start_time": "2020-03-23T15:59:17.744984Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:46:11 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   guid: 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   tokens: [CLS] later that day apple said it was rev ##ising down its earnings expectations in the fourth quarter of 2018 , largely because of lower sales and signs of economic weakness in china . [SEP]\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   input_ids: 101 2101 2008 2154 6207 2056 2009 2001 7065 9355 2091 2049 16565 10908 1999 1996 2959 4284 1997 2760 1010 4321 2138 1997 2896 4341 1998 5751 1997 3171 11251 1999 2859 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   label: None (id = 9090)\n",
      "12/24/2020 12:46:11 - INFO - root -   tensor([[-1.4020,  2.8216, -1.4619],\n",
      "        [-0.8062,  2.0264, -0.9275],\n",
      "        [-1.3513,  2.9371, -1.5855],\n",
      "        [-0.0726,  2.2061, -2.1539],\n",
      "        [-1.4526,  2.9810, -1.3625]])\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   guid: 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   tokens: [CLS] yields on government bonds fell as investors fled to the traditional haven in a market storm . [SEP]\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   input_ids: 101 16189 2006 2231 9547 3062 2004 9387 6783 2000 1996 3151 4033 1999 1037 3006 4040 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   label: None (id = 9090)\n",
      "12/24/2020 12:46:11 - INFO - root -   tensor([[-0.8754,  2.5811, -1.5627]])\n"
     ]
    }
   ],
   "source": [
    "result = predict(text, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:20.519047Z",
     "start_time": "2020-03-23T15:59:20.440450Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>logit</th>\n",
       "      <th>prediction</th>\n",
       "      <th>sentiment_score</th>\n",
       "      <th>textblob_prediction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Later that day Apple said it was revising down its earnings expectations in the fourth quarter of 2018, largely because of lower sales and signs of economic weakness in China.</td>\n",
       "      <td>[0.014240683, 0.9723466, 0.013412752]</td>\n",
       "      <td>negative</td>\n",
       "      <td>-0.958106</td>\n",
       "      <td>0.051746</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>The news rapidly infected financial markets.</td>\n",
       "      <td>[0.05297577, 0.9000978, 0.046926454]</td>\n",
       "      <td>negative</td>\n",
       "      <td>-0.847122</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Apple’s share price fell by around 7% in after-hours trading and the decline was extended to more than 10% when the market opened.</td>\n",
       "      <td>[0.013397858, 0.9760023, 0.010599791]</td>\n",
       "      <td>negative</td>\n",
       "      <td>-0.962604</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>The dollar fell by 3.7% against the yen in a matter of minutes after the announcement, before rapidly recovering some ground.</td>\n",
       "      <td>[0.091838144, 0.89670366, 0.011458234]</td>\n",
       "      <td>negative</td>\n",
       "      <td>-0.804866</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Asian stockmarkets closed down on January 3rd and European ones opened lower.</td>\n",
       "      <td>[0.011584295, 0.9757397, 0.012676031]</td>\n",
       "      <td>negative</td>\n",
       "      <td>-0.964155</td>\n",
       "      <td>-0.051111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Yields on government bonds fell as investors fled to the traditional haven in a market storm.</td>\n",
       "      <td>[0.0301131, 0.95474213, 0.015144793]</td>\n",
       "      <td>negative</td>\n",
       "      <td>-0.924629</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                          sentence  \\\n",
       "0  Later that day Apple said it was revising down its earnings expectations in the fourth quarter of 2018, largely because of lower sales and signs of economic weakness in China.   \n",
       "1  The news rapidly infected financial markets.                                                                                                                                      \n",
       "2  Apple’s share price fell by around 7% in after-hours trading and the decline was extended to more than 10% when the market opened.                                                \n",
       "3  The dollar fell by 3.7% against the yen in a matter of minutes after the announcement, before rapidly recovering some ground.                                                     \n",
       "4  Asian stockmarkets closed down on January 3rd and European ones opened lower.                                                                                                     \n",
       "5  Yields on government bonds fell as investors fled to the traditional haven in a market storm.                                                                                     \n",
       "\n",
       "                                    logit prediction  sentiment_score  \\\n",
       "0  [0.014240683, 0.9723466, 0.013412752]   negative  -0.958106          \n",
       "1  [0.05297577, 0.9000978, 0.046926454]    negative  -0.847122          \n",
       "2  [0.013397858, 0.9760023, 0.010599791]   negative  -0.962604          \n",
       "3  [0.091838144, 0.89670366, 0.011458234]  negative  -0.804866          \n",
       "4  [0.011584295, 0.9757397, 0.012676031]   negative  -0.964155          \n",
       "5  [0.0301131, 0.95474213, 0.015144793]    negative  -0.924629          \n",
       "\n",
       "   textblob_prediction  \n",
       "0  0.051746             \n",
       "1  0.000000             \n",
       "2  0.500000             \n",
       "3  0.000000             \n",
       "4 -0.051111             \n",
       "5  0.000000             "
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
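    "# Baseline for comparison: TextBlob polarity per sentence.\n",
    "# This assumes TextBlob splits the text into the same sentences as predict().\n",
    "blob = TextBlob(text)\n",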
    "result['textblob_prediction'] = [sentence.sentiment.polarity for sentence in blob.sentences]\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:38.737969Z",
     "start_time": "2020-03-23T15:59:38.718255Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average sentiment is -0.91.\n"
     ]
    }
   ],
   "source": [
    "print(f'Average sentiment is {result.sentiment_score.mean():.2f}.')"
   ]
  },
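  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Despite its name, the `logit` column in the table above holds class probabilities: applying a softmax to the raw logits in the log output reproduces them. The `sentiment_score` values are consistent with the probability of the positive class minus the probability of the negative class. The cell below is a minimal sketch of that relationship for the first sentence. It assumes the label order `[positive, negative, neutral]`, which is inferred from the table rather than taken from the FinBERT source, so treat it as an illustration and not as the library's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Raw logits for the first sentence, copied from the log output above.\n",
    "logits = np.array([-1.4020, 2.8216, -1.4619])\n",
    "\n",
    "# Numerically stable softmax: subtract the maximum before exponentiating.\n",
    "exp = np.exp(logits - logits.max())\n",
    "probs = exp / exp.sum()\n",
    "\n",
    "# Assumed label order, inferred from the result table.\n",
    "labels = ['positive', 'negative', 'neutral']\n",
    "prediction = labels[int(probs.argmax())]\n",
    "\n",
    "# Score as it appears in the table: P(positive) - P(negative).\n",
    "sentiment_score = float(probs[0] - probs[1])\n",
    "\n",
    "print(probs.round(4), prediction, round(sentiment_score, 4))"
   ]
  },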
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is another example, this time with predominantly positive news:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:45.922058Z",
     "start_time": "2020-03-23T15:59:45.904622Z"
    }
   },
   "outputs": [],
   "source": [
    "text2 = \"Shares in the spin-off of South African e-commerce group Naspers surged more than 25% \\\n",
    "in the first minutes of their market debut in Amsterdam on Wednesday. Bob van Dijk, CEO of \\\n",
    "Naspers and Prosus Group poses at Amsterdam's stock exchange, as Prosus begins trading on the \\\n",
    "Euronext stock exchange in Amsterdam, Netherlands, September 11, 2019. REUTERS/Piroschka van de Wouw \\\n",
    "Prosus comprises Naspers’ global empire of consumer internet assets, with the jewel in the crown a \\\n",
    "31% stake in Chinese tech titan Tencent. There is 'way more demand than is even available, so that’s \\\n",
    "good,' said the CEO of Euronext Amsterdam, Maurice van Tilburg. 'It’s going to be an interesting \\\n",
    "hour of trade after opening this morning.' Euronext had given an indicative price of 58.70 euros \\\n",
    "per share for Prosus, implying a market value of 95.3 billion euros ($105 billion). The shares \\\n",
    "jumped to 76 euros on opening and were trading at 75 euros at 0719 GMT.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:48.152474Z",
     "start_time": "2020-03-23T15:59:47.028417Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12/24/2020 12:46:11 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   guid: 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   tokens: [CLS] shares in the spin - off of south african e - commerce group nas ##pers surged more than 25 % in the first minutes of their market debut in amsterdam on wednesday . [SEP]\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   input_ids: 101 6661 1999 1996 6714 1011 2125 1997 2148 3060 1041 1011 6236 2177 17235 7347 18852 2062 2084 2423 1003 1999 1996 2034 2781 1997 2037 3006 2834 1999 7598 2006 9317 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:11 - INFO - finbert.utils -   label: None (id = 9090)\n",
      "12/24/2020 12:46:12 - INFO - root -   tensor([[ 2.0964, -1.7291, -0.9076],\n",
      "        [-0.2245, -1.8077,  2.8465],\n",
      "        [ 0.6664, -2.3775,  2.2567],\n",
      "        [ 2.1228, -2.1155, -0.4410],\n",
      "        [ 1.4156, -1.5475,  0.6892]])\n",
      "12/24/2020 12:46:12 - INFO - finbert.utils -   *** Example ***\n",
      "12/24/2020 12:46:12 - INFO - finbert.utils -   guid: 0\n",
      "12/24/2020 12:46:12 - INFO - finbert.utils -   tokens: [CLS] euro ##ne ##xt had given an indicative price of 58 . 70 euros per share for pro ##sus , implying a market value of 95 . 3 billion euros ( $ 105 billion ) . [SEP]\n",
      "12/24/2020 12:46:12 - INFO - finbert.utils -   input_ids: 101 9944 2638 18413 2018 2445 2019 24668 3976 1997 5388 1012 3963 19329 2566 3745 2005 4013 13203 1010 20242 1037 3006 3643 1997 5345 1012 1017 4551 19329 1006 1002 8746 4551 1007 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:12 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:12 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "12/24/2020 12:46:12 - INFO - finbert.utils -   label: None (id = 9090)\n",
      "12/24/2020 12:46:12 - INFO - root -   tensor([[ 0.4527, -2.1428,  2.0321],\n",
      "        [ 1.8974, -2.3414,  0.2543]])\n"
     ]
    }
   ],
   "source": [
    "result2 = predict(text2, model)\n",
    "# TextBlob polarity per sentence, added as a baseline for comparison.\n",
    "blob = TextBlob(text2)\n",
    "result2['textblob_prediction'] = [sentence.sentiment.polarity for sentence in blob.sentences]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T15:59:50.428951Z",
     "start_time": "2020-03-23T15:59:50.402385Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>logit</th>\n",
       "      <th>prediction</th>\n",
       "      <th>sentiment_score</th>\n",
       "      <th>textblob_prediction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Shares in the spin-off of South African e-commerce group Naspers surged more than 25% in the first minutes of their market debut in Amsterdam on Wednesday.</td>\n",
       "      <td>[0.93336153, 0.020355493, 0.046283066]</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.913006</td>\n",
       "      <td>0.250000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Bob van Dijk, CEO of Naspers and Prosus Group poses at Amsterdam's stock exchange, as Prosus begins trading on the Euronext stock exchange in Amsterdam, Netherlands, September 11, 2019.</td>\n",
       "      <td>[0.043918967, 0.009017709, 0.9470634]</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.034901</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>REUTERS/Piroschka van de Wouw Prosus comprises Naspers’ global empire of consumer internet assets, with the jewel in the crown a 31% stake in Chinese tech titan Tencent.</td>\n",
       "      <td>[0.16799231, 0.008004095, 0.8240036]</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.159988</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>There is 'way more demand than is even available, so that’s good,' said the CEO of Euronext Amsterdam, Maurice van Tilburg.</td>\n",
       "      <td>[0.9162205, 0.013223126, 0.07055632]</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.902997</td>\n",
       "      <td>0.533333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>'It’s going to be an interesting hour of trade after opening this morning.'</td>\n",
       "      <td>[0.65132874, 0.03364602, 0.3150252]</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.617683</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Euronext had given an indicative price of 58.70 euros per share for Prosus, implying a market value of 95.3 billion euros ($105 billion).</td>\n",
       "      <td>[0.1687288, 0.012589393, 0.81868184]</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.156139</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>The shares jumped to 76 euros on opening and were trading at 75 euros at 0719 GMT.</td>\n",
       "      <td>[0.8279447, 0.011943376, 0.16011195]</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.816001</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                    sentence  \\\n",
       "0  Shares in the spin-off of South African e-commerce group Naspers surged more than 25% in the first minutes of their market debut in Amsterdam on Wednesday.                                 \n",
       "1  Bob van Dijk, CEO of Naspers and Prosus Group poses at Amsterdam's stock exchange, as Prosus begins trading on the Euronext stock exchange in Amsterdam, Netherlands, September 11, 2019.   \n",
       "2  REUTERS/Piroschka van de Wouw Prosus comprises Naspers’ global empire of consumer internet assets, with the jewel in the crown a 31% stake in Chinese tech titan Tencent.                   \n",
       "3  There is 'way more demand than is even available, so that’s good,' said the CEO of Euronext Amsterdam, Maurice van Tilburg.                                                                 \n",
       "4  'It’s going to be an interesting hour of trade after opening this morning.'                                                                                                                 \n",
       "5  Euronext had given an indicative price of 58.70 euros per share for Prosus, implying a market value of 95.3 billion euros ($105 billion).                                                   \n",
       "6  The shares jumped to 76 euros on opening and were trading at 75 euros at 0719 GMT.                                                                                                          \n",
       "\n",
       "                                    logit prediction  sentiment_score  \\\n",
       "0  [0.93336153, 0.020355493, 0.046283066]  positive   0.913006          \n",
       "1  [0.043918967, 0.009017709, 0.9470634]   neutral    0.034901          \n",
       "2  [0.16799231, 0.008004095, 0.8240036]    neutral    0.159988          \n",
       "3  [0.9162205, 0.013223126, 0.07055632]    positive   0.902997          \n",
       "4  [0.65132874, 0.03364602, 0.3150252]     positive   0.617683          \n",
       "5  [0.1687288, 0.012589393, 0.81868184]    neutral    0.156139          \n",
       "6  [0.8279447, 0.011943376, 0.16011195]    positive   0.816001          \n",
       "\n",
       "   textblob_prediction  \n",
       "0  0.250000             \n",
       "1  0.000000             \n",
       "2  0.000000             \n",
       "3  0.533333             \n",
       "4  0.500000             \n",
       "5  0.000000             \n",
       "6  0.000000             "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T16:00:27.031491Z",
     "start_time": "2020-03-23T16:00:27.012639Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average sentiment is 0.51.\n"
     ]
    }
   ],
   "source": [
    "print(f'Average sentiment is {result2.sentiment_score.mean():.2f}.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "FinBERT",
   "language": "python",
   "name": "finbert"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
